{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Is0dpglTAL_Q"
      },
      "source": [
        "# Automated Mask Detection and Annotation Generation for Image Segmentation Models"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SP7jOoUcBqQM"
      },
      "source": [
        "**Objective**: Given an image containing objects of a single category, this Colab notebook generates a COCO-formatted JSON annotation file for that image. The JSON file can then be converted into a dataset in the TFRecord format, which is used to train image segmentation models."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IHjfsY2SBq3G"
      },
      "source": [
        "**Background**: This approach is motivated by the scarcity of manually annotated data. When an image contains multiple objects of a single category, annotation can be automated, which substantially reduces the cost of manual labeling."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AmeXv7QgAWyD"
      },
      "source": [
        "In this Colab notebook, we use the \"Segment Anything\" model from Facebook Research to detect masks for all objects in the given image. A post-processing step then filters out unrelated masks and keeps only those that belong to the target objects.\n",
        "\n",
        "After the masks are refined, we use the Imantics library to convert the binary masks into the COCO JSON annotation format. This format represents multiple objects in a single structured annotation file that can be used for image segmentation tasks."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-6tUoca-he89"
      },
      "source": [
        "The resulting COCO JSON file from this Colab can be merged with other COCO JSON files using this [notebook](https://github.com/tensorflow/models/blob/master/official/projects/waste_identification_ml/pre_processing/merge_coco_files_faster.ipynb) from our project. The final merged COCO JSON file, along with the corresponding images, can then be converted to the TFRecord format using this [notebook](https://github.com/tensorflow/models/blob/master/official/projects/waste_identification_ml/pre_processing/coco_to_tfrecord.ipynb)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y9DaFu1_HzSr"
      },
      "source": [
        "## Importing and Installing Required Libraries\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "GBhN1q8wG-F-",
        "outputId": "80e08e44-59cf-43cf-8d9e-3fc5be5a4252"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for segment-anything (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for imantics (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
          ]
        }
      ],
      "source": [
        "# This command installs the 'Segment Anything' library directly from its GitHub\n",
        "# repository.\n",
        "# 'Segment Anything' is a project by Facebook Research, which provides tools for\n",
        "# object segmentation.\n",
        "!pip install -q git+https://github.com/facebookresearch/segment-anything.git\n",
        "!pip install -q imantics"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qcVv_9av_4Ko"
      },
      "outputs": [],
      "source": [
        "!git clone --depth 1 https://github.com/tensorflow/models 2\u003e/dev/null"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QqJXuaKkHRja"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import cv2\n",
        "import random\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "from segment_anything import sam_model_registry\n",
        "from segment_anything import SamAutomaticMaskGenerator, SamPredictor\n",
        "from typing import Any\n",
        "import sys\n",
        "\n",
        "sys.path.append('models/official/projects/waste_identification_ml/data_generation/')\n",
        "import utils\n",
        "\n",
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "71j1v6HRHcpC",
        "outputId": "a6380317-affe-4fe7-dd5b-0ee50fc6a30b"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "CUDA is available: True\n"
          ]
        }
      ],
      "source": [
        "# This line of code checks and prints whether CUDA is available on this machine.\n",
        "# CUDA is a parallel computing platform and application programming interface\n",
        "# model created by NVIDIA. It allows developers to use CUDA-enabled graphics\n",
        "# processing units (GPUs) for general purpose processing.\n",
        "print(\"CUDA is available:\", torch.cuda.is_available())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VF0HStCiM4iA"
      },
      "source": [
        "## Load Segment Anything model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gH1Po7fIH8QN"
      },
      "outputs": [],
      "source": [
        "# This command quietly downloads the pre-trained model checkpoint file\n",
        "# 'sam_vit_h_4b8939.pth' for the ViT-H SAM model from Segment Anything project.\n",
        "!wget -q \\\n",
        "'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "X5ikRJqGM79p"
      },
      "outputs": [],
      "source": [
        "# Specify the path to the pre-trained model checkpoint\n",
        "sam_checkpoint = \"sam_vit_h_4b8939.pth\"\n",
        "\n",
        "# Define the model type to be used\n",
        "model_type = \"vit_h\"\n",
        "\n",
        "# Specify the device for the model: use the GPU ('cuda') when it is available,\n",
        "# otherwise fall back to the CPU\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "\n",
        "# Load the pre-trained model using the specified checkpoint and model type from\n",
        "# the SAM model registry\n",
        "sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)\n",
        "\n",
        "# Transfer the model to the specified device\n",
        "sam.to(device=device)\n",
        "\n",
        "# Instantiate an automatic mask generator using the loaded SAM model\n",
        "mask_generator = SamAutomaticMaskGenerator(sam)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BcLM-iFYOEkv"
      },
      "source": [
        "## Inference"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4awsjpHR0QXI"
      },
      "outputs": [],
      "source": [
        "# Download the sample image from the CircularNet project.\n",
        "url = (\n",
        "    \"https://raw.githubusercontent.com/tensorflow/models/master/official/\"\n",
        "    \"projects/waste_identification_ml/pre_processing/config/sample_images/\"\n",
        "    \"image_4.png\"\n",
        ")\n",
        "\n",
        "!curl -O {url} \u003e /dev/null 2\u003e\u00261"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7vdYlvmuN6rF"
      },
      "outputs": [],
      "source": [
        "# Read the image (OpenCV loads it as BGR), convert it to RGB, and resize it\n",
        "# to 1024x512 for inference.\n",
        "original_image = cv2.imread('image_4.png')\n",
        "image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)\n",
        "image = cv2.resize(image, (1024, 512), interpolation=cv2.INTER_AREA)\n",
        "utils.plot_image(image)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "urvv8VNbu4a1"
      },
      "source": [
        "To generate masks, run `generate` on an image.\n",
        "\n",
        "Mask generation returns a list of masks, where each mask is a dictionary containing various data about the mask. The keys are:\n",
        "\n",
        "*   `segmentation` : the mask\n",
        "*   `area` : the area of the mask in pixels\n",
        "*   `bbox` : the bounding box of the mask in XYWH format\n",
        "*   `predicted_iou` : the model's own prediction for the quality of the mask\n",
        "*   `point_coords` : the sampled input point that generated this mask\n",
        "*   `stability_score` : an additional measure of mask quality\n",
        "*   `crop_box` : the crop of the image used to generate this mask in XYWH format"
      ]
    },
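    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the relationship between these keys concrete, the sketch below builds a small synthetic record by hand. It is only an illustration: the array and values are made up, and only three of the keys are shown.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Synthetic stand-in for one entry of the list returned by `generate`.\n",
        "segmentation = np.zeros((6, 8), dtype=bool)\n",
        "segmentation[1:4, 2:6] = True  # 3 rows x 4 columns of foreground pixels\n",
        "\n",
        "mask_record = {\n",
        "    'segmentation': segmentation,\n",
        "    'area': int(segmentation.sum()),\n",
        "    'bbox': [2, 1, 4, 3],  # XYWH: x, y, width, height\n",
        "}\n",
        "\n",
        "# Cropping the mask with its XYWH box selects exactly the foreground region.\n",
        "x, y, w, h = mask_record['bbox']\n",
        "crop = mask_record['segmentation'][y:y + h, x:x + w]\n",
        "print(mask_record['area'], crop.shape, bool(crop.all()))\n",
        "```\n",
        "\n",
        "Here `area` is the number of foreground pixels, and the XYWH box is indexed as rows `y:y + h` and columns `x:x + w`."
      ]
    },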
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ca3OVfEVOoAa"
      },
      "outputs": [],
      "source": [
        "result = mask_generator.generate(image)\n",
        "print(\"Total number of masks found:\", len(result))\n",
        "utils.display_image_with_annotations(image, result)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0Nl0HT-lQ0aP"
      },
      "outputs": [],
      "source": [
        "# Display all the detected masks.\n",
        "utils.plot_grid(result, n_cols=5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FKwArN2rfaJF"
      },
      "source": [
        "The output shows that the model detected many masks that cover parts of the background rather than the objects of interest. It also detected many overlapping masks that belong to the same object."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6dHqnyaZ_EA9"
      },
      "source": [
        "## Convert bbox format"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CJnKCejY_F3b"
      },
      "outputs": [],
      "source": [
        "# Convert each bbox from XYWH to (xmin, ymin, xmax, ymax).\n",
        "for element in result:\n",
        "  element['bbox'] = utils.convert_bbox_format(element['bbox'])"
      ]
    },
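    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of this conversion, the sketch below turns an XYWH box into corner coordinates. The function name `xywh_to_xyxy` is hypothetical, and this is not the implementation of `utils.convert_bbox_format`, which may differ in details such as rounding or output type.\n",
        "\n",
        "```python\n",
        "def xywh_to_xyxy(bbox):\n",
        "  # Convert [x, y, width, height] to [xmin, ymin, xmax, ymax].\n",
        "  x, y, w, h = bbox\n",
        "  return [x, y, x + w, y + h]\n",
        "\n",
        "\n",
        "print(xywh_to_xyxy([10, 20, 30, 40]))  # [10, 20, 40, 60]\n",
        "```\n",
        "\n",
        "The corner format makes it easy to compute the width and height of a box, or to test whether two boxes overlap."
      ]
    },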
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HzQXwPjweXN4"
      },
      "source": [
        "## Mask filtering"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1dYG6oWYgrQL"
      },
      "source": [
        "The next cell analyzes the list of masks (**result**) to find pairs of masks that are nested or similar, based on a nesting score. When a pair of masks has a nesting score greater than **0.95**, the mask with the larger area is identified. An object may have several similar masks, which may or may not be completely nested inside one another. Our goal is to keep only the mask with the maximum area for each object, so that each object has exactly one mask."
      ]
    },
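    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One plausible way to define such a nesting score is the fraction of the smaller mask that lies inside the other mask. The sketch below uses that definition as an assumption; the function `nesting_score` is hypothetical and is not taken from `utils.filter_nested_similar_masks`, whose scoring may differ.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def nesting_score(mask_a, mask_b):\n",
        "  # Fraction of the smaller mask that lies inside the other mask.\n",
        "  intersection = np.logical_and(mask_a, mask_b).sum()\n",
        "  smaller_area = min(mask_a.sum(), mask_b.sum())\n",
        "  return float(intersection) / float(smaller_area) if smaller_area else 0.0\n",
        "\n",
        "\n",
        "# Two synthetic masks: `inner` lies entirely inside `outer`.\n",
        "outer = np.zeros((10, 10), dtype=bool)\n",
        "outer[2:8, 2:8] = True\n",
        "inner = np.zeros((10, 10), dtype=bool)\n",
        "inner[3:6, 3:6] = True\n",
        "\n",
        "score = nesting_score(outer, inner)\n",
        "# For a nested pair, keep the mask with the larger area.\n",
        "larger = outer if outer.sum() >= inner.sum() else inner\n",
        "print(score, int(larger.sum()))\n",
        "```\n",
        "\n",
        "With this definition, a fully nested pair scores 1.0 and two disjoint masks score 0.0, so a threshold such as 0.95 selects pairs that are almost entirely nested."
      ]
    },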
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YzLXiNUKyf2S"
      },
      "outputs": [],
      "source": [
        "if result:\n",
        "  filtered_unnested_results = utils.filter_nested_similar_masks(result)\n",
        "  print(\"Total number of filtered masks found:\", len(filtered_unnested_results))\n",
        "\n",
        "  utils.display_image_with_annotations(image, filtered_unnested_results)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GKNvx_Rfyy9O"
      },
      "outputs": [],
      "source": [
        "# Display all the masks after the previous filtering step.\n",
        "utils.plot_grid(filtered_unnested_results, n_cols=3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XCXjXNwQd5B1"
      },
      "source": [
        "We can see that the model still detected many masks that do not represent objects. We will now filter the masks according to the aspect ratio of their bounding boxes and the area of the masks. Masks that are too elongated, too large, or too small are removed in the steps below."
      ]
    },
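    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The idea behind this kind of filter can be sketched as follows. The function `keep_mask` and all of its thresholds are hypothetical values chosen for illustration; they do not reproduce the logic or the parameters of `utils.filter_masks`.\n",
        "\n",
        "```python\n",
        "def keep_mask(bbox, mask_area, image_area,\n",
        "              max_aspect_ratio=4.0, min_area_frac=0.001, max_area_frac=0.25):\n",
        "  # Return True if a mask passes the aspect-ratio and area checks.\n",
        "  xmin, ymin, xmax, ymax = bbox\n",
        "  width, height = xmax - xmin, ymax - ymin\n",
        "  if width <= 0 or height <= 0:\n",
        "    return False\n",
        "  aspect_ratio = max(width, height) / min(width, height)\n",
        "  area_frac = mask_area / image_area\n",
        "  return (aspect_ratio <= max_aspect_ratio\n",
        "          and min_area_frac <= area_frac <= max_area_frac)\n",
        "\n",
        "\n",
        "image_area = 1024 * 512\n",
        "print(keep_mask([0, 0, 100, 80], 6000, image_area))     # compact, medium size\n",
        "print(keep_mask([0, 0, 500, 20], 8000, image_area))     # too elongated\n",
        "print(keep_mask([0, 0, 900, 500], 400000, image_area))  # too large\n",
        "print(keep_mask([0, 0, 10, 10], 80, image_area))        # too small\n",
        "```\n",
        "\n",
        "Only the first box passes all three checks. Expressing the area limits as fractions of the image area keeps the thresholds independent of the image resolution."
      ]
    },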
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "H2dq41zSS7Mf"
      },
      "outputs": [],
      "source": [
        "UPPER_MULTIPLIER = 6\n",
        "LOWER_MULTIPLIER = 2\n",
        "AREA_FILTER_THRESH = 0.15"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xGWYChGzVWme"
      },
      "outputs": [],
      "source": [
        "if filtered_unnested_results:\n",
        "  filtered_masks = utils.filter_masks(\n",
        "      image,\n",
        "      filtered_unnested_results,\n",
        "      UPPER_MULTIPLIER,\n",
        "      LOWER_MULTIPLIER,\n",
        "      AREA_FILTER_THRESH,\n",
        "  )\n",
        "  print(\"Total number of filtered masks found:\", len(filtered_masks))\n",
        "\n",
        "  utils.display_image_with_annotations(image, filtered_masks)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ooU5ELnIYNiD"
      },
      "outputs": [],
      "source": [
        "# Display all the masks after the previous filtering step.\n",
        "utils.plot_grid(filtered_masks, n_cols=5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XIchLh0BGjc7"
      },
      "source": [
        "## Conversion to COCO JSON format"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lUNKHxzZGnpK"
      },
      "outputs": [],
      "source": [
        "# Convert each boolean mask to a uint8 image (0 or 255) and resize it back to\n",
        "# the resolution of the original image.\n",
        "if filtered_masks:\n",
        "  final_masks = []\n",
        "  for i in filtered_masks:\n",
        "    mask_uint8 = i['segmentation'].astype(np.uint8) * 255\n",
        "    resized_mask_uint8 = cv2.resize(\n",
        "        mask_uint8,\n",
        "        (original_image.shape[1], original_image.shape[0]),\n",
        "        interpolation=cv2.INTER_NEAREST\n",
        "    )\n",
        "    final_masks.append(resized_mask_uint8)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OrhqGaqXfLZi"
      },
      "outputs": [],
      "source": [
        "# Assign the category name for all objects in the required COCO JSON file.\n",
        "category_name = 'Plastics_PP'\n",
        "\n",
        "# Desired name of an image in COCO JSON file.\n",
        "image_name = 'xyz.png'\n",
        "\n",
        "coco_json_file = utils.generate_coco_json(\n",
        "    final_masks,\n",
        "    original_image,\n",
        "    category_name,\n",
        "    image_name\n",
        ")"
      ]
    },
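    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To merge or convert the annotations later, the dictionary has to be written to disk. The sketch below shows one way to do that with the standard `json` module. It uses a small stand-in dictionary and a hypothetical output path so that it runs on its own; it also assumes that every value in the dictionary is JSON-serializable.\n",
        "\n",
        "```python\n",
        "import json\n",
        "import os\n",
        "import tempfile\n",
        "\n",
        "# Stand-in for `coco_json_file`; the real dictionary comes from the cell above.\n",
        "example_coco = {\n",
        "    'images': [{'id': 0, 'file_name': 'xyz.png', 'width': 1920, 'height': 1080}],\n",
        "    'categories': [{'id': 1, 'name': 'Plastics_PP'}],\n",
        "    'annotations': [],\n",
        "}\n",
        "\n",
        "# Hypothetical output location; change it to wherever the file should live.\n",
        "output_path = os.path.join(tempfile.gettempdir(), 'xyz_coco.json')\n",
        "with open(output_path, 'w') as f:\n",
        "  json.dump(example_coco, f)\n",
        "\n",
        "# Read the file back to confirm that it round-trips.\n",
        "with open(output_path) as f:\n",
        "  reloaded = json.load(f)\n",
        "print(sorted(reloaded.keys()))\n",
        "```\n",
        "\n",
        "In this notebook, passing `coco_json_file` in place of `example_coco` would write the generated annotations. If the dictionary holds NumPy scalars or arrays, they would need to be converted to plain Python types first."
      ]
    },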
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "h3NrZyNMe8oT",
        "outputId": "ae1dabc3-6e49-4aa1-87d9-6eaac741bd34"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[{'id': 0,\n",
              "  'width': 1920,\n",
              "  'height': 1080,\n",
              "  'file_name': 'xyz.png',\n",
              "  'path': '',\n",
              "  'license': None,\n",
              "  'fickr_url': None,\n",
              "  'coco_url': None,\n",
              "  'date_captured': None,\n",
              "  'metadata': {}}]"
            ]
          },
          "execution_count": 19,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Display the information about the image.\n",
        "coco_json_file['images']"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "HuV9D2bYVj3c",
        "outputId": "ec290c97-eed1-4a76-b8c5-93527c3085cd"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[{'id': 1,\n",
              "  'name': 'Plastics_PP',\n",
              "  'supercategory': None,\n",
              "  'metadata': {},\n",
              "  'color': '#98f270'}]"
            ]
          },
          "execution_count": 20,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Display the information about the categories of the objects present in that\n",
        "# image.\n",
        "coco_json_file['categories']"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "yZ_n7CwuU6__",
        "outputId": "76724453-54f8-4703-b948-307af371fa83"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "5"
            ]
          },
          "execution_count": 21,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Number of objects which were detected in an image.\n",
        "len(coco_json_file['annotations'])"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "machine_shape": "hm",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
